name: conda

on:
  # Every commit on main, and every tag matching the version pattern below.
  push:
    branches:
      - main
    tags:
      - "v[0-9]+*"
  # Pull requests only trigger a build when one of these paths changes.
  pull_request:
    paths:
      - "packaging/**"
      - ".github/workflows/conda.yml"
      - "setup.py"
      - "requirements*.txt"
  # Allows the workflow to be started manually.
  workflow_dispatch: {}

# this yaml file can be cleaned up using yaml anchors, but they're not supported in github actions yet
# https://github.com/actions/runner/issues/1182

env:
  # CUDA compute capabilities to compile for. Some of the code compiled here
  # needs at least compute capability 5.0. A later step appends 9.0 (H100)
  # unless the CUDA version is 11.6 or 11.7.
  TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX"
  # Kept low to avoid OOMs during the build.
  MAX_JOBS: "3"
  XFORMERS_BUILD_TYPE: "Release"
  # "conda-" followed by the short name of the ref that triggered the run.
  XFORMERS_PACKAGE_FROM: "conda-${{ github.ref_name }}"

jobs:
  build:
    strategy:
      # Let the remaining matrix jobs finish even when one of them fails.
      fail-fast: false
      matrix:
        # Quoted so that "3.10" is not parsed as the float 3.1.
        python:
          - "3.9"
          - "3.10"
        # One entry per PyTorch/CUDA combination; each is built for every
        # Python version listed above.
        config:
          - cuda_version: "12.1.0"
            cuda_short_version: "121"
            cuda_dep_runtime: ">=12.0,<13.0"
            torch_version: "2.2.0"
            torch_channel: "pytorch"

          - cuda_version: "11.8.0"
            cuda_short_version: "118"
            cuda_dep_runtime: ">=11.7,<11.9"
            torch_version: "2.2.0"
            torch_channel: "pytorch"

    name: py${{ matrix.python }}-torch${{ matrix.config.torch_version }}-cu${{ matrix.config.cuda_version }}
    runs-on: 8-core-ubuntu  # 32GB RAM, 8 vCPUs
    timeout-minutes: 360

    # The job runs inside the conda-builder image for the matrix CUDA version.
    container:
      image: pytorch/conda-builder:cuda${{ matrix.config.cuda_short_version }}

    defaults:
      run:
        shell: bash

    env:
      # Python interpreter that the steps below invoke as $PY.
      PY: /opt/conda/bin/python
    steps:
      - name: System info
        # Diagnostics only: working directory, disk usage, ldd and Python
        # versions, then memory and swap.
        run: |
          pwd
          df -h
          ldd --version
          "${PY}" --version
          printf '%s\n' "Memory and swap:"
          free -h
          printf '\n'
          swapon --show
      # Default upload label; the next step overrides it for version tags.
      - name: Set tag = 'dev'
        run: |
          echo "XFORMERS_CONDA_TAG=dev" >> "${GITHUB_ENV}"
      - name: Set tag = 'main'
        if: startsWith(github.ref, 'refs/tags/v')
        run: |
          echo "XFORMERS_CONDA_TAG=main" >> "${GITHUB_ENV}"
      # Log the label that ended up selected.
      - run: |
          echo "${XFORMERS_CONDA_TAG}"
      # Appends compute capability 9.0 to TORCH_CUDA_ARCH_LIST unless the
      # matrix CUDA version is 11.6 or 11.7, then exports the resulting value
      # to the following steps through the GITHUB_ENV file.
      - name: Add H100 if nvcc 11.08+
        shell: python
        run: |
          import os
          import sys
          print(sys.version)
          cuda_short_version = "${{ matrix.config.cuda_short_version }}"
          arch_list = os.environ["TORCH_CUDA_ARCH_LIST"]
          if cuda_short_version not in ["116", "117"]:
            arch_list += " 9.0"
          # Append instead of opening with "r+": "r+" starts writing at
          # offset 0 and would overwrite anything already in the file.
          with open(os.environ["GITHUB_ENV"], "a") as fp:
            fp.write("TORCH_CUDA_ARCH_LIST=" + arch_list + "\n")
      # Log the arch list that later steps will see.
      - run: echo "${TORCH_CUDA_ARCH_LIST}"
      - name: Free up disk space
        # Prints root filesystem usage before and after the cleanup.
        run: |
          df -h /
          sudo rm -rf \
            /__t/Python \
            /__t/CodeQL \
            /__t/go \
            /__t/PyPy \
            /__t/node
          sudo rm -rf /opt/conda/pkgs
          df -h /
      - name: Recursive checkout
        uses: actions/checkout@v3
        with:
          path: "."
          # Check out submodules too, including nested ones.
          submodules: recursive
          # Full history, for tags.
          fetch-depth: 0
      - name: Define version
        env:
          # 'tag' when the run was triggered by a tag, 'dev' otherwise.
          VERSION_SOURCE: ${{ github.ref_type == 'tag' && 'tag' || 'dev' }}
        run: |
          set -Eeuo pipefail
          git config --global --add safe.directory "*"
          "${PY}" -m pip install packaging
          # Capture the script's stdout and export it as BUILD_VERSION.
          build_version="$("${PY}" packaging/compute_wheel_version.py --source "${VERSION_SOURCE}")"
          echo "BUILD_VERSION=${build_version}" >> "${GITHUB_ENV}"
          cat "${GITHUB_ENV}"
      - name: Build & store/upload
        env:
          # Expression "ternary": ${{ cond && 'trueVal' || 'falseVal' }}.
          # Evaluates to an empty string for the 'pytorch' channel.
          STORE_PT_PACKAGE: ${{ matrix.config.torch_channel != 'pytorch' && '--store-pytorch-package' || '' }}
        run: |
          # $STORE_PT_PACKAGE is deliberately left unquoted: when it is empty
          # it must expand to no argument at all.
          "${PY}" packaging/build_conda.py \
            --cuda ${{ matrix.config.cuda_version }} \
            --python ${{ matrix.python }} \
            --pytorch ${{ matrix.config.torch_version }} \
            --pytorch-channel ${{ matrix.config.torch_channel }} \
            --cuda-dep-runtime "${{ matrix.config.cuda_dep_runtime }}" \
            --store $STORE_PT_PACKAGE
      - name: Upload
        # Only from the upstream repository, never for pull requests, and only
        # for builds against the 'pytorch' channel.
        if: github.repository == 'facebookresearch/xformers' && github.event_name != 'pull_request' && matrix.config.torch_channel == 'pytorch'
        env:
          ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }}
        run: |
          # The upload might fail, so retry up to 20 times, waiting 15 seconds
          # between attempts.
          for attempt in {1..20}; do
              if [[ $attempt != 1 ]]; then
                  sleep 15
              fi
              echo "Attempt $attempt"
              # Keep `&& break` as the last command of the loop body: if every
              # attempt fails, the loop's exit status is that of the final
              # failed upload, which is what makes this step fail.
              anaconda upload --user xformers --label "${XFORMERS_CONDA_TAG}" packages/* && break
          done
      # Store the contents of packages/ as a workflow artifact, also when an
      # earlier step failed (but not when the run was cancelled).
      - uses: actions/upload-artifact@v3
        if: success() || failure()
        with:
          # Include the CUDA version, as the job name does, so every matrix job
          # gets its own artifact. Without it, the CUDA 11.8 and 12.1 jobs for
          # the same Python/torch pair upload concurrently into one artifact.
          name: linux-py${{ matrix.python }}-torch${{ matrix.config.torch_version }}-cu${{ matrix.config.cuda_version }}
          path: packages
# Note: it might be helpful to have additional steps that test if the built conda packages actually work
